Conversation

@csukuangfj
Collaborator

See also lhotse-speech/lhotse#1477

(I prefer to keep it inside icefall. If moving it to lhotse is preferred, I can do that.)
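
Roughly, the new dataset collates the usual batch and, for every cut transform, one extra augmented view under batch["aug"]. The sketch below is only an illustration of that idea, not the actual class added in this PR; the class name is made up, and it assumes collation can be delegated to lhotse's K2SpeechRecognitionDataset with each transform applied per cut.

from typing import Callable, List, Optional

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset


class ConsistencyRegularizationSketch(K2SpeechRecognitionDataset):
    """Illustrative only: returns the usual batch plus one extra collated
    view per cut transform, mirroring the batch layout in the test below."""

    def __init__(self, cut_transforms: Optional[List[Callable]] = None, **kwargs):
        # The transforms are deliberately NOT forwarded to the parent class:
        # the parent would apply them sequentially to every batch, whereas
        # here each transform produces a separate augmented copy.
        super().__init__(**kwargs)
        self.view_transforms = cut_transforms or []

    def __getitem__(self, cuts: CutSet) -> dict:
        batch = super().__getitem__(cuts)
        if self.view_transforms:
            # Note: a real implementation should make sure the cut order
            # stays consistent between the original and augmented views.
            batch["aug"] = [
                super(ConsistencyRegularizationSketch, self).__getitem__(
                    CutSet.from_cuts(t(c) for c in cuts)
                )
                for t in self.view_transforms
            ]
        return batch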

Test code

#!/usr/bin/env python3

from functools import partial

import lhotse
from lhotse import Fbank, FbankConfig
from lhotse.dataset import SimpleCutSampler
from lhotse.dataset.input_strategies import (
    AudioSamples,
    OnTheFlyFeatures,
)
from torch.utils.data.dataloader import DataLoader

from speech_recognition_dataset import ConsistencyRegularizationSpeechRecognitionDataset


def create_cutset():
    # Build a tiny CutSet with two cuts (one supervision each) for testing.
    recording0 = lhotse.Recording.from_file("./0.wav")
    sup0 = lhotse.SupervisionSegment(
        id="sup0",
        recording_id=recording0.id,
        start=0,
        duration=5,
        text="hello, how are you",
    )
    cut0 = lhotse.MonoCut(
        id="cut0",
        start=0,
        duration=6,
        channel=0,
        recording=recording0,
        supervisions=[sup0],
    )

    recording1 = lhotse.Recording.from_file("./1.wav")
    sup1 = lhotse.SupervisionSegment(
        id="sup1",
        recording_id=recording1.id,
        start=0,
        duration=3,
        text="fine, thank you",
    )
    cut1 = lhotse.MonoCut(
        id="cut1",
        start=0,
        duration=4,
        channel=0,
        recording=recording1,
        supervisions=[sup1],
    )

    return lhotse.CutSet([cut0, cut1])


def main():
    cutset = create_cutset()
    print([c for c in cutset])

    # Each transform yields one extra augmented view of every batch.
    t1 = partial(lhotse.MonoCut.perturb_speed, factor=0.9)
    t2 = partial(lhotse.MonoCut.perturb_volume, factor=1.1)
    t3 = partial(lhotse.MonoCut.perturb_tempo, factor=1.2)
    transforms = [t1, t2, t3]

    train = ConsistencyRegularizationSpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        #  input_strategy=AudioSamples(),
        cut_transforms=transforms,
        #  return_cuts=True,
        return_cuts=False,
    )
    print(train)
    train_sampler = SimpleCutSampler(
        cutset,
        max_duration=20,
        shuffle=False,
    )

    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )
    for b in train_dl:
        print(b["inputs"].shape, b["supervisions"])
        if "aug" in b:
            assert len(b["aug"]) == len(transforms)
            for i, aug in enumerate(b["aug"]):
                print(
                    transforms[i].func.__name__,
                    aug["inputs"].shape,
                    aug["supervisions"],
                )


if __name__ == "__main__":
    main()

The output is given below. Note how the shapes track the perturbation factors: perturb_speed with factor 0.9 stretches 600 frames to about 600 / 0.9 ≈ 667, perturb_tempo with factor 1.2 shrinks them to 600 / 1.2 = 500, and perturb_volume leaves the frame count unchanged:

[MonoCut(id='cut0', start=0, duration=6, channel=0, supervisions=[SupervisionSegment(id='sup0', recording_id='0', start=0, duration=5, channel=0, text='hello, how are you', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='0', sources=[AudioSource(type='file', channels=[0], source='0.wav')], sampling_rate=16000, num_samples=106000, duration=6.625, channel_ids=[0], transforms=None), custom=None), MonoCut(id='cut1', start=0, duration=4, channel=0, supervisions=[SupervisionSegment(id='sup1', recording_id='1', start=0, duration=3, channel=0, text='fine, thank you', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='1', sources=[AudioSource(type='file', channels=[0], source='1.wav')], sampling_rate=16000, num_samples=81600, duration=5.1, channel_ids=[0], transforms=None), custom=None)]
<speech_recognition_dataset.ConsistencyRegularizationSpeechRecognitionDataset object at 0x119f1c8e0>
torch.Size([2, 600, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([500, 300], dtype=torch.int32)}
perturb_speed torch.Size([2, 667, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([556, 333], dtype=torch.int32)}
perturb_volume torch.Size([2, 600, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([500, 300], dtype=torch.int32)}
perturb_tempo torch.Size([2, 500, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([417, 250], dtype=torch.int32)}
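
For context, one way a training loop could consume such a batch is to add a consistency term between the model outputs on the original inputs and on each augmented view. The sketch below is only an illustration and is not part of this PR: the model, the choice of divergence, and the time-axis handling are placeholder assumptions. Since speed/tempo perturbation changes the number of frames, the views cannot be compared frame by frame without an alignment step; the sketch sidesteps that by comparing time-averaged output distributions.

import torch.nn.functional as F


def consistency_loss(model, batch):
    # Placeholder: assume `model` maps (N, T, F) features to per-frame
    # log-probabilities of shape (N, T, C).
    logprobs = model(batch["inputs"])
    # Average over time to avoid frame-level alignment between views whose
    # lengths differ (see the perturb_speed / perturb_tempo shapes above).
    ref = logprobs.exp().mean(dim=1)                 # (N, C)

    loss = logprobs.new_zeros(())
    for aug in batch.get("aug", []):
        aug_logprobs = model(aug["inputs"])          # (N, T', C)
        hyp = aug_logprobs.exp().mean(dim=1)         # (N, C)
        # F.kl_div expects log-probabilities as the input and probabilities
        # as the target.
        loss = loss + F.kl_div(
            hyp.clamp_min(1e-8).log(), ref, reduction="batchmean"
        )
    return loss

The consistency term would then be added, with some weight, to the usual supervised loss (e.g. CTC or transducer) computed on the original batch.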

@pzelasko
Collaborator

Looks good to me. +1 for keeping this code in Icefall; I think it's more convenient to have it close to the training recipe.
